/*********************************************************************
Replication code for Systemic Discrimination Among Large U.S. Employers
Patrick M. Kline, Evan K. Rose, Christopher R. Walters
April, 2022

This code produces the estimates of heterogeneity in contact gaps across
various groups of jobs. The output of this code can be copied into the
second tables of the spreadsheets in /tables to reproduce the formatted
output reported in the paper.

It requires the top-level directory to be set to the replication folder
using the global `dir` below.
*********************************************************************/

* Root of the replication folder; every input and output path below uses it
global dir "/accounts/projects/pkline/randres/randres/replication"

* Start from a clean session and fix the seed so random draws are reproducible
capture restore
clear all
set seed 126312

* Empty temporary dataset that accumulates one row per specification
tempfile allresults
clear
save `allresults', replace emptyok

* Number of bootstrap replications
global bsreps 500

* Function for estimating group effects in application-level data
capture program drop estfirmapps
program define estfirmapps, rclass
    * group_var: variable whose levels each get their own contact gap
    * attribute: 0/1 applicant characteristic defining the gap
    * estimator: estimation command fit with cluster(job_id)
    * Returns r(ests) [gap, squared SE per group], r(nfirms) [distinct firms
    * per group, 1 when grouping by firm] and r(ngroups).
    args group_var attribute estimator

    preserve

    * Keep jobs that received applications with both attribute values and
    * flag the attribute == 1 applications
    bys job_id: egen nattr = nvals(`attribute')
    qui keep if nattr == 2
    gen D = `attribute' == 1

    * Consecutive integer code per group, a dummy f_g for each group and its
    * interaction D_g with the attribute flag
    egen f = group(`group_var')
    qui summarize f, meanonly
    local F = r(max)
    forvalues g = 1/`F' {
        gen f_`g' = (f == `g')
        gen D_`g' = D * f_`g'
    }

    * Result slots: (gap, squared SE) and number of firms for each group
    matrix res = J(`F',2,.)
    matrix ng = J(`F',1,1)

    * One pooled fit; the D_g coefficients are the group gaps
    qui `estimator' cb f_* D_* [iweight=w], cluster(job_id)
    local by_firm = ("`group_var'" == "firm_id")
    forvalues g = 1/`F' {
        matrix res[`g',1] = _b[D_`g']
        matrix res[`g',2] = _se[D_`g']^2
        if !`by_firm' {
            qui distinct firm_id if f == `g'
            matrix ng[`g',1] = r(ndistinct)
        }
    }

    return matrix ests = res
    return matrix nfirms = ng
    return scalar ngroups = `F'
    restore
end

* Function for estimating group effects in job-level data
capture program drop estfirmjobs
program define estfirmjobs, rclass
    * group_var: variable whose levels each get their own contact gap
    * attribute: 0/1 applicant characteristic defining the gap
    * estimator: estimation command run separately within each group
    * Returns r(ests) [gap, squared SE per group], r(nfirms) [firm-weighted
    * count per group, 1 when grouping by firm] and r(ngroups).
    args group_var attribute estimator

    preserve

    * One row per job: cb0/cb1 are the job's mean cb by attribute value and
    * cb = cb1 - cb0 its gap; jobs missing either value are dropped
    gen fid = firm_id
    collapse (mean) cb w (first) fid, by(jid `group_var' `attribute')
    qui reshape wide cb, i(jid `group_var' w) j(`attribute')
    qui gen cb = cb1 - cb0
    qui drop if cb == .

    * Weight each job by 1/(number of jobs at its firm) so firms count equally
    bys fid: egen njobs = nvals(jid)
    gen double fweight = 1/njobs
    qui replace w = w * fweight

    * Consecutive integer code and a dummy f_g for each group
    egen f = group(`group_var')
    qui summarize f, meanonly
    local F = r(max)
    forvalues g = 1/`F' {
        gen f_`g' = (f == `g')
    }

    * Result slots: (gap, squared SE) and firm count for each group
    matrix res = J(`F',2,.)
    matrix ng = J(`F',1,1)

    local by_firm = ("`group_var'" == "firm_id")
    forvalues g = 1/`F' {
        * Inside group g only f_g is nonzero, so its coefficient is the
        * weighted mean job-level gap of the group
        qui `estimator' cb f_* [iweight=w] if f == `g', vce(robust) nocons
        matrix res[`g',1] = _b[f_`g']
        matrix res[`g',2] = _se[f_`g']^2
        if !`by_firm' {
            * Sum of the 1/njobs firm weights of the group's jobs
            qui total fweight if f == `g'
            matrix ng[`g',1] = _b[fweight]
        }
    }

    return matrix ests = res
    return matrix nfirms = ng
    return scalar ngroups = `F'
    restore
end

* Firm contact gaps estimation function
capture program drop estfirm
program define estfirm, rclass
    * Dispatcher: "reg" uses the job-level routine (estfirmjobs), any other
    * estimator uses the application-level routine (estfirmapps). The
    * worker's r(ests), r(nfirms) and r(ngroups) are re-posted unchanged.
    args group_var attribute estimator

    local worker = cond("`estimator'" == "reg", "estfirmjobs", "estfirmapps")
    `worker' `group_var' `attribute' `estimator'

    * Copy the worker's results into this program's return list
    local F = r(ngroups)
    mat res = r(ests)
    mat nfirms = r(nfirms)
    return matrix ests = res
    return matrix nfirms = nfirms
    return scalar ngroups = `F'
end

* Function for computing variance component based on normal approximation
capture program drop varnormal
program define varnormal, rclass
    * Uses r(ests) (col 1 = gaps b_g, col 2 = squared SEs s2_g) and
    * r(nfirms) (group weights n_g, total F) left by estfirm:
    *   plug       = sum(n_g * b_g^2)/F - mu^2,  mu = sum(n_g * b_g)/sum(n_g)
    *   correction = sum(s2_g * n_g * (F - n_g)) / F^2
    *   btwnvar    = plug - correction   (btwnsd is its square root)
    mata: ests = st_matrix("r(ests)")[,1]
    mata: se2 = st_matrix("r(ests)")[,2]
    mata: weights = st_matrix("r(nfirms)")
    mata: F = quadsum(weights)
    mata: ests2 = ests:^2
    mata: mu = quadsum(ests:*weights) / sum(weights)
    mata: plug = quadsum(ests2:*weights)/F - mu^2
    mata: correction = quadsum(se2:*(weights:*(F:-weights))) / F^2
    mata: btwnvar = plug - correction
    mata: st_matrix("btwnvar",btwnvar)
    return scalar btwnvar = btwnvar[1,1]
    return scalar btwnsd = sqrt(btwnvar[1,1])
end

* Mata code to get ustat
* ustater(): split-sample (U-statistic style) estimate of the between-group
* variance of the mean gap. Reads Stata matrices built by varustat:
*   fsplitmeans - one row per (group, split): group id, split id, mean gap
*   jweights    - summed job weight of each (group, split) row
*   fweights    - summed firm weight of each group
* and writes the scalar estimate to Stata matrix btwnvar.
cap mata: mata drop ustater()
mata
void ustater() {
    fsplitmeans = st_matrix("fsplitmeans")
    jweights = st_matrix("jweights")
    fweights = st_matrix("fweights")
    fobs = length(fweights)
    nf = sum(fweights)
    nobs = length(fsplitmeans[.,1])
    // Group-membership indicator matrix (rows = splits, cols = groups).
    // Assumes rows are sorted by group (varustat sorts them), so a new
    // group starts whenever column 1 changes between consecutive rows.
    fids = J(nobs,fobs,0)
    k = 1
    for (i=1; i<=nobs; i++) {
        fids[i,k] = 1
        if (i < nobs) {
            if (fsplitmeans[i,1] != fsplitmeans[i+1,1]) {
                k++
            }
        }
    }
    // Job-weighted mean gap of each group, scaled by its group weight.
    // The off-diagonal sum of the outer product gives the cross-group
    // terms sum_{k!=l} w_k w_l mu_k mu_l / nf^2 (no own-group squares).
    weightsums = fids'*jweights
    fmeans = (fids'*(fsplitmeans[.,3]:*jweights)):/weightsums
    fmeans = fmeans:*fweights
    fmean_prod = fmeans*fmeans'
    fmean_prod = sum(fmean_prod - diag(fmean_prod)) / nf^2

    // Squared mean of each group estimated only from products of distinct
    // splits (own-split products removed), so split-level noise does not
    // enter the square.
    // NOTE(review): a group with a single split gives 0/0 = missing here,
    // and quadsum() below treats missing as zero -- callers must ensure
    // every group has at least two splits.
    fsquares = J(fobs,1,0)
    for (k=1; k<=fobs; k++) {
        fjweights = select(jweights,fids[.,k])
        fjobs = select(fsplitmeans[.,3],fids[.,k]):*fjweights
        prod = fjobs * fjobs'
        wsum = fjweights * fjweights'
        fsquares[k] = sum(prod - diag(prod)) / sum(wsum - diag(wsum))
    }
    // btwnvar = sum_k w_k (nf - w_k) m2_k / nf^2 - cross terms above,
    // which equals the weighted variance of group means when m2_k = mu_k^2
    fsquares = fsquares:*(fweights:*(nf:-fweights))
    
    btwnvar = quadsum(fsquares)/nf^2 - fmean_prod
    st_matrix("btwnvar", btwnvar)
    }
end

* Mata code to construct quantiles
* Drop any previous definition so the file can be re-run in one session
* ("mata drop" is the Mata command; a bare "drop" is not a Mata statement)
cap mata: mata drop get_qtile()
mata
// get_qtile(x, q): empirical q-quantile of column vector x, 0 < q < 1.
// With n = rows(x) and idx = n*q: a fractional idx returns the
// ceil(idx)-th order statistic; an integer idx returns the average of the
// idx-th and (idx+1)-th order statistics. Raises error 3300 when q is not
// strictly inside (0, 1).
real scalar get_qtile(real colvector x, real scalar q)
{
    real colvector tmp
    real scalar idx

    if ((q <= 0) | (q >= 1)) {
        _error(3300, "get_qtile(): q must lie strictly between 0 and 1")
    }

    tmp = sort(x, 1)
    idx = rows(x)*q

    if (mod(idx, 1)) {
        return(tmp[ceil(idx)])
    }
    else {
        // idx < rows(x) because q < 1, so idx + 1 is a valid index
        return(((tmp[idx] + tmp[idx + 1])/2))
    }
}
end

* Mata code to conduct Bai e.a. test
* Drop any previous definition so the file can be re-run in one session
cap mata: mata drop baitest()
mata
// baitest(): simulation-based test on the group gaps in r(ests)
// (column 1 = estimates, column 2 = squared SEs). Requires get_qtile().
// Statistic: Tn = max_j max(t_j, 0), t_j = estimate_j / SE_j.
// For each alpha on the grid 0.999, 0.998, ..., 0.001 (beta = alpha/10):
//   c1 = (1 - beta) quantile of the max of nj independent N(0,1) draws
//   c2 = (1 - alpha + beta) quantile of max_j (Z_j + min(t_j + c1, 0))
// Returns the first (largest) alpha with Tn < c2, or 0 if Tn >= c2 over the
// whole grid; presumably read as an approximate p-value -- confirm.
real scalar baitest() {
    ests = st_matrix("r(ests)")
    tmp = ests[.,1]:/(ests[.,2]:^0.5)
    nj = length(tmp)
    Zj = J(nj,2,0)
    Zj[.,1] = tmp
    Tn = max(rowmax(Zj))

    // Simulation draws are made once and reused at every alpha
    B = 10000
    normal_maxes = colmax(rnormal(nj,B,0,1))'
    normals = rnormal(nj,B,0,1)

    // Integer-indexed grid: repeatedly subtracting 0.001 accumulates
    // rounding error, and a tiny positive residual alpha could make
    // 1 - beta round to 1 and trip get_qtile()'s (0, 1) check
    alpha = 0
    for (step=999; step>=1; step--) {
        a = step/1000
        beta = a/10
        c1 = get_qtile(normal_maxes,1-beta)

        c2 = Zj
        c2[.,1] = c2[.,1]:+ c1
        c2 = rowmin(c2)
        c2 = colmax(normals + c2*J(1,B,1))'
        c2 = get_qtile(c2,1-a+beta)
        if (Tn < c2) {
            alpha = a
            break
        }
    }
    return(alpha)
}
end

* Function for split-sample estimator of variance
capture program drop varustat
program define varustat, rclass
    * group_var: grouping variable; attribute: 0/1 characteristic defining
    * the gap; split: variable whose values define the sample splits within
    * each group. Returns r(btwnvar) and r(btwnsd) computed by ustater().
    args group_var attribute split

    preserve

    * Use only groups observed in more than one split (re-checked below)
    bys `group_var': egen nobs = nvals(`split')
    qui keep if nobs > 1

    * Number of distinct firms within each group_var level
    bys `group_var': egen nfirms = nvals(firm_id)
    gen fid = firm_id

    * First collapse to jobs and form each job's gap dif = cb1 - cb0;
    * jobs lacking either attribute value are dropped
    collapse (mean) cb w (max) nfirms, by(jid fid `split' `group_var' `attribute')
    qui reshape wide cb, i(jid fid `split' `group_var' w) j(`attribute')
    qui gen dif = cb1 - cb0
    drop cb0 cb1
    qui drop if dif == .

    * Dropping those jobs can leave a group with a single split, for which
    * ustater() would compute 0/0 and quadsum() would silently count it as
    * zero. Re-apply the multiple-split restriction on the final sample.
    tempvar nsplits
    bys `group_var': egen `nsplits' = nvals(`split')
    qui keep if `nsplits' > 1
    drop `nsplits'

    * Weights to get firm-weighted estimates: 1/(jobs at the job's firm)
    bys fid: egen njobs = nvals(jid)
    gen double fweight = 1/njobs
    qui replace w = w * fweight

    * Then collapse by split: weighted mean gap plus summed job/firm weights
    collapse (mean) dif (rawsum) w fweight (max) nfirms [aw=w], by(`split' `group_var')
    sort `group_var' `split'
    mkmat `group_var' `split' dif, matrix(fsplitmeans)
    mkmat w, matrix(jweights)

    * Group-level firm weights
    collapse (max) nfirms (sum) fweight, by(`group_var')
    sort `group_var'
    mkmat fweight, matrix(fweights)

    * Get ustat estimate (ustater() writes it to matrix btwnvar); read the
    * matrix directly rather than through a macro to keep full precision
    mata: ustater()
    return scalar btwnvar = btwnvar[1,1]
    return scalar btwnsd = sqrt(btwnvar[1,1])
    restore
end

* Function for bootstrap 
capture program drop bootvarest
program define bootvarest, rclass
    * nboots:     number of bootstrap replications
    * group_var:  grouping variable passed to estfirm / varustat
    * attribute:  0/1 characteristic defining the gap
    * estimator:  "ustat" for the split-sample estimator, otherwise an
    *             estimation command handed to estfirm (e.g. reg)
    * split:      split variable, used only by the ustat estimator
    * Overwrites the variable w: 1 for the point estimate, then one
    * exponential draw per job_id in each bootstrap replication.
    args nboots group_var attribute estimator split

    * First get the original estimates with unit weights
    qui replace w = 1
    if "`estimator'" == "ustat" {
        varustat `group_var' `attribute' `split'
    }
    else {
        estfirm `group_var' `attribute' `estimator'
        mat tmp = r(ests)
        return matrix ests = tmp
        return scalar ngroups = r(ngroups)

        * Chi2 statistic: squared deviations of the group gaps from their
        * unweighted mean, scaled by the squared SEs
        * NOTE(review): the analytic p-value uses r(ngroups) degrees of
        * freedom; deviations from an estimated mean usually suggest
        * ngroups - 1 -- confirm this is intended.
        mata: ests_orig = st_matrix("r(ests)")
        mata: chi2 = sum(((ests_orig[.,1]:- mean(ests_orig[.,1])):^2):/ests_orig[.,2])
        mata: st_matrix("chi2", chi2)
        return scalar chi2 = chi2[1,1]
        return scalar chi2p_analytic = chi2tail(r(ngroups), chi2[1,1])

        * Unweighted mean of the group gaps (each firm counts equally when
        * grouping by firm)
        mata: fmean = mean(ests_orig[.,1])
        mata: st_matrix("fmean", fmean)
        return scalar fmean = fmean[1,1]

        * Bias-corrected variance from the normal approximation
        varnormal
    }
    return scalar btwnvar = r(btwnvar)
    return scalar btwnsd = r(btwnsd)
    local btwnsd_orig = r(btwnsd)

    * Replications. Columns of bsres: btwnvar, btwnsd, chi2 (centred at the
    * original estimates), mean gap. Each job_id gets a single weight
    * -ln(U), U uniform, shared by all its observations.
    matrix bsres = J(`nboots',4,.)
    foreach s of numlist 1/`nboots' {
        di "Working on bootstrap loop `s'"
        capture drop _tmp*
        qui bys job_id: gen double _tmp = -ln(runiform()) if _n == 1
        qui bys job_id: egen double _tmp2 = max(_tmp)
        qui replace w = _tmp2

        if "`estimator'" == "ustat" {
            varustat `group_var' `attribute' `split'
        }
        else {
            estfirm `group_var' `attribute' `estimator'
            mata: ests = st_matrix("r(ests)")
            mata: bschi2 = sum(((ests[.,1] - ests_orig[.,1]):^2):/ests[.,2])
            mata: st_matrix("bschi2", bschi2)
            matrix bsres[`s',3] = bschi2[1,1]

            * Unweighted mean of the bootstrap group gaps
            mata: fmean = mean(ests[.,1])
            mata: st_matrix("fmean", fmean)
            matrix bsres[`s',4] = fmean[1,1]

            varnormal
        }
        matrix bsres[`s',1] = r(btwnvar)
        matrix bsres[`s',2] = r(btwnsd)
        * A negative variance estimate has no real square root; record SD 0
        if r(btwnvar) < 0 {
            matrix bsres[`s',2] = 0
        }

    }
    * Bootstrap standard errors of the variance and SD estimates, plus the
    * delta-method SE of the SD: se(var) / (2 * sd)
    mata: bsvar = quadvariance(st_matrix("bsres")[.,1..2])
    mata: st_matrix("bsvar", bsvar)
    return scalar se_btwnvar = sqrt(bsvar[1,1])
    return scalar se_btwnsd = sqrt(bsvar[2,2])
    return scalar delta_btwnsd = 1/2 * sqrt(bsvar[1,1]) / `btwnsd_orig'

    * Bootstrap chi2 p-value and SE of the mean gap (non-ustat only)
    if "`estimator'" != "ustat" {
        mata: chi2p = mean(st_matrix("bsres")[.,3]:>chi2)
        mata: st_matrix("chi2p", chi2p)
        return scalar chi2p = chi2p[1,1]
        mata: bsmean_var = quadvariance(st_matrix("bsres")[.,4])
        mata: st_matrix("bsmean_var", bsmean_var)
        return scalar fmean_se = sqrt(bsmean_var[1,1])
    }
end

* Loop over the two samples: bal = 0 keeps observations with balanced >= 0,
* bal = 1 keeps only those with balanced >= 1
foreach bal of numlist 0/1 {
    use "${dir}/data/data.dta", clear

    * Constructed variables: firm mean of cb, unit weights, job identifier
    bys firm_id: egen mean_cb = mean(cb)
    gen w = 1
    egen jid = group(firm_id job_id)

    * Restrict the sample
    * NOTE(review): missing values of balanced compare as larger than any
    * number, so they are kept in both samples -- confirm intended.
    keep if balanced >= `bal'

    * Collect all the estimators
    foreach group_var of varlist firm_id st sic_combined intermed soc3 {
        foreach attribute of varlist black female over40 lgbtq_club gender_neutral_pronouns {
            foreach method in "reg" "ustat" {
                foreach split of varlist job_id wave st {
                    * Splits other than job_id are only used by ustat
                    if "`split'" != "job_id" & "`method'" != "ustat" {
                        continue
                    }
                    * Never split a state grouping by state
                    if "`split'" == "st" & "`group_var'" == "st" {
                        continue
                    }
                    * Non-firm groupings use job_id splits only
                    if "`group_var'" != "firm_id" & "`split'" != "job_id" {
                        continue
                    }

                    di "Working on `group_var' `attribute' `method' `split'"
                    preserve
                    * A failure anywhere in this block skips the specification;
                    * its return code is reported right after the block
                    capture {
                        * Estimate variance
                        bootvarest ${bsreps} `group_var' `attribute' `method' `split'
                        local varest = r(btwnvar)
                        local varest_se = r(se_btwnvar)
                        local sdest = r(btwnsd)
                        local sdest_se = r(se_btwnsd)
                        local sdest_se_delta = r(delta_btwnsd)
                        local ngroups = r(ngroups)
                        local chi2 = r(chi2)
                        local chi2p = r(chi2p)
                        local chi2p_analytic = r(chi2p_analytic)
                        local fmean = r(fmean)
                        local fmean_se = r(fmean_se)

                        * Bai tests for the attribute and its complement (reg only)
                        if "`method'" == "reg" {
                            replace w = 1
                            estfirm `group_var' `attribute' reg
                            mata: ans = baitest()
                            mata: st_local("bai1", strofreal(ans))
                            gen tmp_attr = 1-`attribute'
                            estfirm `group_var' tmp_attr reg
                            mata: ans = baitest()
                            mata: st_local("bai2", strofreal(ans))
                        }

                        * Save output
                        clear
                        set obs 1
                        gen group_var = "`group_var'"
                        gen attribute = "`attribute'"
                        gen method = "`method'"
                        gen varest = `varest'
                        gen varest_se = `varest_se'
                        gen sdest = `sdest'
                        gen sdest_se = `sdest_se'
                        gen sdest_se_delta = `sdest_se_delta'
                        gen ngroups = `ngroups'
                        gen split = "`split'"
                        gen chi2 = `chi2'
                        gen chi2p = `chi2p'
                        gen chi2p_analytic = `chi2p_analytic'
                        gen balanced = `bal'
                        gen fmean = `fmean'
                        gen fmean_se = `fmean_se'

                        * Record the bai tests if the method is reg
                        if "`method'" == "reg" {
                            gen bai1 = `bai1'
                            gen bai0 = `bai2'
                        }

                        append using `allresults'
                        save `allresults', replace
                    }
                    local rc = _rc
                    if `rc' {
                        di as error "Skipped `group_var' `attribute' `method' `split' (bal = `bal'): r(`rc')"
                    }
                    else {
                        list if _n == 1
                    }
                    restore
                    * Propagate a user Break instead of silently moving on
                    if `rc' == 1 {
                        exit 1
                    }
                }
            }
        }
    }
}

* Export the accumulated results, identifying columns first; paths are
* quoted so a replication folder containing spaces still works
use `allresults', clear
order balanced group_var attribute  method  varest  varest_se  sdest   sdest_se    sdest_se_delta  ngroups split   chi2    chi2p   chi2p_analytic   bai1   bai0 
outsheet using "${dir}/dump/table4_etc.csv", comma replace


